{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# C host"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "import tvm.testing\n",
    "\n",
    "from tvm import te\n",
    "from tvm.contrib import utils\n",
    "from tvm.script import tir as T, ir as I\n",
    "\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## add"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Element-wise sum C = A + B over 1-D tensors of length `nn`.\n",
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\")\n",
    "B = te.placeholder((n,), name=\"B\")\n",
    "C = te.compute(\n",
    "    A.shape,\n",
    "    lambda *idx: A(*idx) + B(*idx),\n",
    "    name=\"C\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the TE graph into a PrimFunc exposed under the symbol \"test_fadd\",\n",
    "# then compile it with the C target.\n",
    "prim_func = te.create_prim_func([A, B, C]).with_attr(\"global_symbol\", \"test_fadd\")\n",
    "mhost = tvm.compile(tvm.IRModule.from_expr(prim_func), target=\"c\")\n",
    "\n",
    "# Round-trip through an exported library file in a temporary directory.\n",
    "temp = utils.tempdir()\n",
    "path_dso = temp.relpath(\"temp.so\")\n",
    "mhost.export_library(path_dso)\n",
    "m = tvm.runtime.load_module(path_dso)\n",
    "fadd = m[\"test_fadd\"]\n",
    "\n",
    "# Launch the kernel on the CPU and compare against the NumPy result.\n",
    "dev = tvm.cpu(0)\n",
    "n = nn\n",
    "a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)\n",
    "b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)\n",
    "c = tvm.nd.array(np.zeros(n, dtype=C.dtype), dev)\n",
    "fadd(a, b, c)\n",
    "expected = a.numpy() + b.numpy()\n",
    "tvm.testing.assert_allclose(c.numpy(), expected)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## reinterpret"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# B[i] is the int32 value (2 + A[i]) passed through the `tir.reinterpret`\n",
    "# intrinsic with a float32 result type.\n",
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\", dtype=\"int32\")\n",
    "B = te.compute(\n",
    "    A.shape,\n",
    "    lambda *idx: tvm.tir.call_intrin(\"float32\", \"tir.reinterpret\", 2 + A(*idx)),\n",
    "    name=\"B\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the reinterpret kernel with the C target under the symbol\n",
    "# \"test_reinterpret\".\n",
    "mhost = tvm.compile(\n",
    "    tvm.IRModule.from_expr(\n",
    "        te.create_prim_func([A, B]).with_attr(\"global_symbol\", \"test_reinterpret\")\n",
    "    ),\n",
    "    target=\"c\",\n",
    ")\n",
    "# Round-trip through an exported library file in a temporary directory.\n",
    "temp = utils.tempdir()\n",
    "path_dso = temp.relpath(\"temp.so\")\n",
    "mhost.export_library(path_dso)\n",
    "m = tvm.runtime.load_module(path_dso)\n",
    "# Use a dedicated handle name so the add kernel's `fadd` is not overwritten.\n",
    "freinterpret = m[\"test_reinterpret\"]\n",
    "dev = tvm.cpu(0)\n",
    "n = nn\n",
    "a = tvm.nd.array(np.random.randint(-(2**30), 2**30, size=n).astype(A.dtype), dev)\n",
    "b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "freinterpret(a, b)\n",
    "# The reference reinterprets the bits of (2 + a) as float32 via ndarray.view.\n",
    "tvm.testing.assert_allclose(b.numpy(), (2 + a.numpy()).view(\"float32\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ceil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# B[i] applies the `tir.ceil` intrinsic to the float32 input A[i].\n",
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\", dtype=\"float32\")\n",
    "B = te.compute(\n",
    "    A.shape,\n",
    "    lambda *idx: tvm.tir.call_intrin(\"float32\", \"tir.ceil\", A(*idx)),\n",
    "    name=\"B\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the ceil kernel with the C target under the symbol \"test_ceil\".\n",
    "mhost = tvm.compile(\n",
    "    tvm.IRModule.from_expr(\n",
    "        te.create_prim_func([A, B]).with_attr(\"global_symbol\", \"test_ceil\")\n",
    "    ),\n",
    "    target=\"c\",\n",
    ")\n",
    "# Round-trip through an exported library file in a temporary directory.\n",
    "temp = utils.tempdir()\n",
    "path_dso = temp.relpath(\"temp.so\")\n",
    "mhost.export_library(path_dso)\n",
    "m = tvm.runtime.load_module(path_dso)\n",
    "fceil = m[\"test_ceil\"]\n",
    "dev = tvm.cpu(0)\n",
    "n = nn\n",
    "a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev)\n",
    "b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "fceil(a, b)\n",
    "# The input is cast to A.dtype (float32), so np.ceil already yields float32;\n",
    "# no bit-level `.view` is needed on the reference.\n",
    "tvm.testing.assert_allclose(b.numpy(), np.ceil(a.numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## floor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# B[i] applies the `tir.floor` intrinsic to the float32 input A[i].\n",
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\", dtype=\"float32\")\n",
    "B = te.compute(\n",
    "    A.shape,\n",
    "    lambda *idx: tvm.tir.call_intrin(\"float32\", \"tir.floor\", A(*idx)),\n",
    "    name=\"B\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the floor kernel with the C target under the symbol \"test_floor\".\n",
    "mhost = tvm.compile(\n",
    "    tvm.IRModule.from_expr(\n",
    "        te.create_prim_func([A, B]).with_attr(\"global_symbol\", \"test_floor\")\n",
    "    ),\n",
    "    target=\"c\",\n",
    ")\n",
    "# Round-trip through an exported library file in a temporary directory.\n",
    "temp = utils.tempdir()\n",
    "path_dso = temp.relpath(\"temp.so\")\n",
    "mhost.export_library(path_dso)\n",
    "m = tvm.runtime.load_module(path_dso)\n",
    "ffloor = m[\"test_floor\"]\n",
    "dev = tvm.cpu(0)\n",
    "n = nn\n",
    "a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev)\n",
    "b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "ffloor(a, b)\n",
    "# The input is cast to A.dtype (float32), so np.floor already yields float32;\n",
    "# no bit-level `.view` is needed on the reference.\n",
    "tvm.testing.assert_allclose(b.numpy(), np.floor(a.numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## round"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# B[i] applies the `tir.round` intrinsic to the float32 input A[i].\n",
    "nn = 1024\n",
    "n = tvm.runtime.convert(nn)\n",
    "A = te.placeholder((n,), name=\"A\", dtype=\"float32\")\n",
    "B = te.compute(A.shape, lambda *i: tvm.tir.call_intrin(\"float32\", \"tir.round\", A(*i)), name=\"B\")\n",
    "\n",
    "# Compile the round kernel with the C target under the symbol \"test_round\".\n",
    "mhost = tvm.compile(\n",
    "    tvm.IRModule.from_expr(\n",
    "        te.create_prim_func([A, B]).with_attr(\"global_symbol\", \"test_round\")\n",
    "    ),\n",
    "    target=\"c\",\n",
    ")\n",
    "# Round-trip through an exported library file in a temporary directory.\n",
    "temp = utils.tempdir()\n",
    "path_dso = temp.relpath(\"temp.so\")\n",
    "mhost.export_library(path_dso)\n",
    "m = tvm.runtime.load_module(path_dso)\n",
    "fround = m[\"test_round\"]\n",
    "dev = tvm.cpu(0)\n",
    "n = nn\n",
    "a = tvm.nd.array(np.random.rand(n).astype(A.dtype), dev)\n",
    "b = tvm.nd.array(np.zeros(n, dtype=B.dtype), dev)\n",
    "fround(a, b)\n",
    "# The input is float32, so np.round already yields float32; no `.view` is needed.\n",
    "# NOTE(review): np.round rounds exact halves to even; if the generated C code\n",
    "# rounds halves away from zero, an input of exactly 0.5 would differ.\n",
    "# Presumably negligible with random inputs; confirm.\n",
    "tvm.testing.assert_allclose(b.numpy(), np.round(a.numpy()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 子程序调用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过 `private=True` 标记的 subroutine 不应出现在导出函数列表"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Module with one public entry point that calls a private subroutine.\n",
    "@I.ir_module\n",
    "class mod:\n",
    "    @T.prim_func\n",
    "    def main(A: T.Buffer(1, dtype=\"float32\")):\n",
    "        mod.subroutine(A.data)\n",
    "\n",
    "    @T.prim_func(private=True)\n",
    "    def subroutine(A_data: T.handle(\"float32\")):\n",
    "        A = T.decl_buffer(1, dtype=\"float32\", data=A_data)\n",
    "        A[0] = 42.0\n",
    "\n",
    "built = tvm.tir.build(mod, target=\"c\")\n",
    "\n",
    "# Only the public function may be listed by the built module.\n",
    "func_names = list(built[\"get_func_names\"]())\n",
    "assert \"main\" in func_names, \"Externally exposed functions should be listed in available functions.\"\n",
    "assert \"subroutine\" not in func_names, \"Internal function should not be listed in available functions.\"\n",
    "\n",
    "# Each (pattern, expected count, failure message) is checked against the generated source.\n",
    "source = built.get_source()\n",
    "expected_counts = [\n",
    "    (\"main(void*\", 2, \"Expected two occurrences, for forward-declaration and definition\"),\n",
    "    (\"subroutine(float*\", 2, \"Expected two occurrences, for forward-declaration and definition\"),\n",
    "    (\n",
    "        \"subroutine(\",\n",
    "        3,\n",
    "        \"Expected three occurrences, for forward-declaration, definition, and call from main.\",\n",
    "    ),\n",
    "]\n",
    "for pattern, expected, message in expected_counts:\n",
    "    assert source.count(pattern) == expected, message\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
